{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-Tune a RoBERTa Model and Create a Text Classifier (Sentiment Analysis)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "BERT is built on the Transformer, a neural network architecture based on the attention mechanism. This is, not coincidentally, the name of the popular Python library, “Transformers,” maintained by a company called Hugging Face, which provides BERT and many related models. We will use a variant of BERT called [RoBERTa](https://arxiv.org/abs/1907.11692), a Robustly Optimized BERT Pretraining Approach."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker\n",
    "import pandas as pd\n",
    "\n",
    "sess   = sagemaker.Session()\n",
    "bucket = sess.default_bucket()\n",
    "role = sagemaker.get_execution_role()\n",
    "region = boto3.Session().region_name\n",
    "account_id = boto3.client('sts').get_caller_identity().get('Account')\n",
    "\n",
    "sm = boto3.Session().client(service_name='sagemaker', region_name=region)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Retrieve Pre-Processed Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r processed_train_data_s3_uri"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(processed_train_data_s3_uri)\n",
    "!aws s3 ls $processed_train_data_s3_uri/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r processed_validation_data_s3_uri"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(processed_validation_data_s3_uri)\n",
    "!aws s3 ls $processed_validation_data_s3_uri/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r processed_test_data_s3_uri"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(processed_test_data_s3_uri)\n",
    "!aws s3 ls $processed_test_data_s3_uri/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Specify S3 `Distribution Strategy`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.inputs import TrainingInput\n",
    "\n",
    "s3_input_train_data = TrainingInput(s3_data=processed_train_data_s3_uri, \n",
    "                                         distribution='ShardedByS3Key') \n",
    "s3_input_validation_data = TrainingInput(s3_data=processed_validation_data_s3_uri, \n",
    "                                              distribution='ShardedByS3Key')\n",
    "s3_input_test_data = TrainingInput(s3_data=processed_test_data_s3_uri, \n",
    "                                        distribution='ShardedByS3Key')\n",
    "\n",
    "print(s3_input_train_data.config)\n",
    "print(s3_input_validation_data.config)\n",
    "print(s3_input_test_data.config)"
   ]
  },
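  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `ShardedByS3Key`, each training instance receives only a subset of the S3 objects under the prefix, whereas the default `FullyReplicated` copies every object to every instance. The sketch below only illustrates the idea: the file names are hypothetical, and the round-robin assignment in `shard_by_key` is an assumption made for illustration, not SageMaker's actual implementation.\n",
    "\n",
    "```python\n",
    "# Illustration only: round-robin assignment is an assumption, not SageMaker's algorithm.\n",
    "def shard_by_key(s3_keys, instance_count):\n",
    "    # One list of keys per training instance\n",
    "    shards = [[] for _ in range(instance_count)]\n",
    "    for index, key in enumerate(sorted(s3_keys)):\n",
    "        shards[index % instance_count].append(key)\n",
    "    return shards\n",
    "\n",
    "# Hypothetical object names under the training prefix\n",
    "hypothetical_keys = ['part-0000.tsv', 'part-0001.tsv', 'part-0002.tsv', 'part-0003.tsv']\n",
    "shards = shard_by_key(hypothetical_keys, instance_count=2)\n",
    "\n",
    "for instance_index, shard in enumerate(shards):\n",
    "    print('instance {}: {}'.format(instance_index, shard))\n",
    "```\n",
    "\n",
    "Because each instance sees different objects, the prefix should ideally contain at least as many objects as there are training instances."
   ]
  },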
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set Up Hyper-Parameters for Classification Layer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Choosing a `max_seq_length` for RoBERTa\n",
    "Since a smaller `max_seq_length` leads to faster training and lower resource utilization, we want to find the smallest review length that captures `70%` of our reviews.\n",
    "\n",
    "Remember our distribution of review lengths from a previous section?\n",
    "\n",
    "<img src=\"img/review_word_count_distribution.png\" width=\"50%\" align=\"left\">\n",
    "\n",
    "```\n",
    "mean         67.930174\n",
    "std         130.954079\n",
    "min           1.000000\n",
    "10%           4.000000\n",
    "20%          14.000000\n",
    "30%          21.000000\n",
    "40%          25.000000\n",
    "50%          31.000000\n",
    "60%          42.000000\n",
    "70%          59.000000\n",
    "80%          87.000000\n",
    "90%         149.000000\n",
    "100%       5347.000000\n",
    "max        5347.000000\n",
    "```"
   ]
  },
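  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Review length `59` represents the `70th` percentile for this dataset. However, it is common practice to use a power of 2 for the sequence length with BERT-style models, so let's choose `64`, the smallest power of 2 greater than `59`. Reviews longer than `64` will be truncated to `64`. Note that the lengths above are word counts, while the tokenizer counts sub-word tokens, so somewhat more than `30%` of reviews may be truncated in practice."
   ]
  },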
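  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the rounding step above, the helper below (a hypothetical `next_power_of_two`, not part of this notebook's training code) returns the smallest power of 2 that is at least as large as a given length. The value `59` is the 70th percentile from the table above.\n",
    "\n",
    "```python\n",
    "def next_power_of_two(n):\n",
    "    # Smallest power of 2 that is >= n, for any integer n >= 1\n",
    "    return 1 << (n - 1).bit_length()\n",
    "\n",
    "p70_review_length = 59  # 70th percentile of review lengths\n",
    "max_seq_len = next_power_of_two(p70_review_length)\n",
    "print(max_seq_len)\n",
    "```\n",
    "\n",
    "For a length of `59` this yields `64`, matching the value used for `max_seq_len` in this notebook."
   ]
  },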
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_seq_len=64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name='roberta-base'\n",
    "epochs=3\n",
    "lr=2e-5\n",
    "train_batch_size=64\n",
    "train_steps_per_epoch=100\n",
    "validation_batch_size=64\n",
    "test_batch_size=64\n",
    "seed=42\n",
    "backend='gloo'\n",
    "train_instance_count=2\n",
    "train_instance_type='ml.p3.2xlarge'\n",
    "train_volume_size=1024\n",
    "enable_sagemaker_debugger=True\n",
    "input_mode='File'\n",
    "run_validation=True\n",
    "run_test=False\n",
    "run_sample_predictions=False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hyperparameters={\n",
    "        'model_name': model_name,\n",
    "        'epochs': epochs,\n",
    "        'lr': lr,\n",
    "        'train_batch_size': train_batch_size,\n",
    "        'train_steps_per_epoch': train_steps_per_epoch,\n",
    "        'validation_batch_size': validation_batch_size,\n",
    "        'test_batch_size': test_batch_size,\n",
    "        'seed': seed,\n",
    "        'max_seq_len': max_seq_len,\n",
    "        'backend': backend,\n",
    "        'enable_sagemaker_debugger': enable_sagemaker_debugger,\n",
    "        'run_validation': run_validation,\n",
    "        'run_sample_predictions': run_sample_predictions}"
   ]
  },
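  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In SageMaker script mode, each entry in `hyperparameters` is passed to the entry-point script as a command-line argument (for example `--epochs 3`). The sketch below shows one way a training script might parse a few of them with `argparse`. It is a hypothetical example, not the contents of `src/train_simple.py`, and the `str_to_bool` helper is an assumed name. Booleans need care because they arrive as the strings `True` or `False`.\n",
    "\n",
    "```python\n",
    "import argparse\n",
    "\n",
    "def str_to_bool(value):\n",
    "    # Hypothetical helper: booleans arrive as the strings 'True' / 'False'\n",
    "    return str(value).lower() in ('true', '1', 'yes')\n",
    "\n",
    "def parse_args(argv=None):\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument('--model_name', type=str, default='roberta-base')\n",
    "    parser.add_argument('--epochs', type=int, default=3)\n",
    "    parser.add_argument('--lr', type=float, default=2e-5)\n",
    "    parser.add_argument('--max_seq_len', type=int, default=64)\n",
    "    parser.add_argument('--run_validation', type=str_to_bool, default=True)\n",
    "    # parse_known_args ignores arguments that this sketch does not declare\n",
    "    args, _ = parser.parse_known_args(argv)\n",
    "    return args\n",
    "\n",
    "# Simulate the arguments the script would receive for a few hyperparameters\n",
    "args = parse_args(['--epochs', '3', '--lr', '2e-05', '--max_seq_len', '64', '--run_validation', 'True'])\n",
    "print(args.epochs, args.lr, args.max_seq_len, args.run_validation)\n",
    "```\n",
    "\n",
    "Using `type=bool` here would not work as intended, because any non-empty string, including `'False'`, is truthy in Python."
   ]
  },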
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set Up Metrics To Track Model Performance\n",
    "\n",
    "This sample log line...\n",
    "```\n",
    "[step: 0] val_loss: 0.55 - val_acc: 74.64%\n",
    "```\n",
    "\n",
    "...will produce the following 2 metrics in CloudWatch:\n",
    "\n",
    "`validation:loss` =  0.55\n",
    "\n",
    "`validation:accuracy` = 74.64\n",
    "\n",
    "Log lines containing `train_loss` and `train_acc` are captured the same way as `train:loss` and `train:accuracy`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/cloudwatch_train_accuracy.png\" width=\"50%\" align=\"left\">\n",
    "\n",
    "<img src=\"img/cloudwatch_train_loss.png\" width=\"50%\" align=\"left\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metric_definitions = [\n",
    "     {'Name': 'train:loss', 'Regex': 'train_loss: ([0-9\\\\.]+)'},\n",
    "     {'Name': 'train:accuracy', 'Regex': 'train_acc: ([0-9\\\\.]+)'},\n",
    "     {'Name': 'validation:loss', 'Regex': 'val_loss: ([0-9\\\\.]+)'},\n",
    "     {'Name': 'validation:accuracy', 'Regex': 'val_acc: ([0-9\\\\.]+)'},\n",
    "]"
   ]
  },
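  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To check that these patterns match the log format, the regexes can be applied locally to the sample log line with Python's `re` module. This is a sketch for verification only. SageMaker performs the actual extraction on the training job's logs.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "sample_log_line = '[step: 0] val_loss: 0.55 - val_acc: 74.64%'\n",
    "\n",
    "# Same names and patterns as the metric definitions in this notebook\n",
    "metric_definitions = [\n",
    "    {'Name': 'train:loss', 'Regex': r'train_loss: ([0-9\\.]+)'},\n",
    "    {'Name': 'train:accuracy', 'Regex': r'train_acc: ([0-9\\.]+)'},\n",
    "    {'Name': 'validation:loss', 'Regex': r'val_loss: ([0-9\\.]+)'},\n",
    "    {'Name': 'validation:accuracy', 'Regex': r'val_acc: ([0-9\\.]+)'},\n",
    "]\n",
    "\n",
    "# Collect the first captured group of every pattern that matches the line\n",
    "extracted = {}\n",
    "for metric in metric_definitions:\n",
    "    match = re.search(metric['Regex'], sample_log_line)\n",
    "    if match:\n",
    "        extracted[metric['Name']] = float(match.group(1))\n",
    "\n",
    "print(extracted)\n",
    "```\n",
    "\n",
    "Only the two validation patterns match this line. The capture group stops before the trailing `%`, so the accuracy is recorded as a plain number."
   ]
  },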
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set Up SageMaker Debugger\n",
    "Configure the Debugger hook to save the `all` tensor collection every 10 steps. Built-in Debugger rules are described here:  https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-built-in-rules.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.debugger import Rule\n",
    "from sagemaker.debugger import rule_configs\n",
    "from sagemaker.debugger import CollectionConfig\n",
    "from sagemaker.debugger import DebuggerHookConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "debugger_hook_config = DebuggerHookConfig(\n",
    "    s3_output_path='s3://{}'.format(bucket),\n",
    "    hook_parameters={\n",
    "        \"save_interval\": \"10\",\n",
    "    },\n",
    "    collection_configs=[\n",
    "        CollectionConfig(\n",
    "            name=\"all\"\n",
    "        )\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set Up Our RoBERTa + PyTorch Script to Run on SageMaker\n",
    "Prepare our PyTorch model to run on the managed SageMaker service."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!pygmentize ./src/train_simple.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.pytorch import PyTorch as PyTorchEstimator\n",
    "\n",
    "estimator = PyTorchEstimator(\n",
    "    entry_point='train_simple.py',\n",
    "    source_dir='src',\n",
    "    role=role,\n",
    "    instance_count=train_instance_count,\n",
    "    instance_type=train_instance_type,\n",
    "    volume_size=train_volume_size,\n",
    "    py_version='py3',\n",
    "    framework_version='1.6.0',\n",
    "    hyperparameters=hyperparameters,\n",
    "    metric_definitions=metric_definitions,\n",
    "    input_mode=input_mode,\n",
    "    debugger_hook_config=debugger_hook_config\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "estimator.fit(inputs={'train': s3_input_train_data, \n",
    "                      'validation': s3_input_validation_data,\n",
    "                      'test': s3_input_test_data\n",
    "                     },\n",
    "              wait=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_job_name = estimator.latest_training_job.name\n",
    "print('Training Job Name:  {}'.format(training_job_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display, HTML\n",
    "\n",
    "display(HTML('<b>Review <a target=\"_blank\" href=\"https://console.aws.amazon.com/sagemaker/home?region={}#/jobs/{}\">Training Job</a> After About 5 Minutes</b>'.format(region, training_job_name)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display, HTML\n",
    "\n",
    "display(HTML('<b>Review <a target=\"_blank\" href=\"https://console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/TrainingJobs;prefix={};streamFilter=typeLogStreamPrefix\">CloudWatch Logs</a> After About 5 Minutes</b>'.format(region, training_job_name)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display, HTML\n",
    "\n",
    "display(HTML('<b>Review <a target=\"_blank\" href=\"https://s3.console.aws.amazon.com/s3/buckets/{}/{}/?region={}&tab=overview\">S3 Output Data</a> After The Training Job Has Completed</b>'.format(bucket, training_job_name, region)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "estimator.latest_training_job.wait(logs=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# _Wait Until the ^^ Training Job ^^ Completes Above!_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_s3_uri = estimator.model_data\n",
    "print(model_s3_uri)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p ./tmp/model/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws s3 cp s3://$bucket/$training_job_name/output/model.tar.gz ./tmp/model/model.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!tar -xvzf ./tmp/model/model.tar.gz -C ./tmp/model/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pass Variables to the Next Notebook(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store model_s3_uri"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store training_job_name"
   ]
  },
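  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Look up where SageMaker Debugger stored this job's tensors so the next notebook can load them\n",
    "training_job_debugger_artifacts_path = estimator.latest_job_debugger_artifacts_path()\n",
    "%store training_job_debugger_artifacts_path"
   ]
  },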
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Release Resources"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%html\n",
    "\n",
    "<p><b>Shutting down your kernel for this notebook to release resources.</b></p>\n",
    "<button class=\"sm-command-button\" data-commandlinker-command=\"kernelmenu:shutdown\" style=\"display:none;\">Shutdown Kernel</button>\n",
    "        \n",
    "<script>\n",
    "try {\n",
    "    els = document.getElementsByClassName(\"sm-command-button\");\n",
    "    els[0].click();\n",
    "}\n",
    "catch(err) {\n",
    "    // NoOp\n",
    "}    \n",
    "</script>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%javascript\n",
    "\n",
    "try {\n",
    "    Jupyter.notebook.save_checkpoint();\n",
    "    Jupyter.notebook.session.delete();\n",
    "}\n",
    "catch(err) {\n",
    "    // NoOp\n",
    "}"
   ]
  }
 ],
 "metadata": {
  "instance_type": "ml.t3.medium",
  "kernelspec": {
   "display_name": "Python 3 (Data Science)",
   "language": "python",
   "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/datascience-1.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
